{# ----------------------------------------------------------------------------
 # SymForce - Copyright 2025, Skydio, Inc.
 # This source code is under the Apache 2.0 license found in the LICENSE file.
 # ---------------------------------------------------------------------------- #}
#pragma once

#include <cstddef>
#include <cstdint>
#include <utility>

#include <cuda_runtime.h>

#include "shared_indices.h"
#include "solver_params.h"

namespace caspar {

class {{ solver.struct_name }} {
 public:
  /**
   * Base constructor.
   *
   * @param params: The params to use for the solver
  {% for thing in solver.size_contributors %}
   * @param {{num_arg_key(thing)}} the maximum number of {{name_key(thing)}}s
  {% endfor %}
   */
  {{ solver.struct_name }}(
      const SolverParams &params,
      {% for thing in solver.size_contributors %}
      size_t {{num_arg_key(thing)}}{{ ", " if not loop.last else "" }}
      {% endfor %}
  );

  // This class is managing cuda memory and cannot be copied.
  {{ solver.struct_name }}(const {{ solver.struct_name }}&) = delete;
  {{ solver.struct_name }}& operator=(const {{ solver.struct_name }}&) = delete;

  /**
   * Move constructor: takes over the allocation and all bookkeeping of `other`.
   *
   * A defaulted (member-wise) move would leave both objects holding the same origin_ptr_, so the
   * allocation would be handed to the destructor twice. Here the new object starts from the
   * empty state given by the default member initializers and swaps with `other`, which leaves
   * `other` with null pointers and zero sizes.
   *
   * NOTE(review): this assumes the destructor tolerates a null origin_ptr_ — confirm against
   * its definition.
   */
  {{ solver.struct_name }}({{ solver.struct_name }}&& other) noexcept
      : params_(std::move(other.params_)) {
    swap_state(other);
  }

  /**
   * Move assignment: exchanges the state of this object with `other`.
   *
   * The allocation previously held by this object ends up in `other` and is released when
   * `other` is destroyed, instead of being overwritten and lost as with a defaulted move.
   */
  {{ solver.struct_name }}& operator=({{ solver.struct_name }}&& other) noexcept {
    if (this != &other) {
      std::swap(params_, other.params_);
      swap_state(other);
    }
    return *this;
  }

  ~{{ solver.struct_name }}();

  /**
   * Set the solver parameters.
   */
  void set_params(const SolverParams &params);

  /**
   * Run the solver.
   *
   * @param print_progress: Whether to print progress information while solving
   * @return A float result of the run (presumably the final score — confirm in the
   *         implementation)
   */
  float solve(bool print_progress = false);

  /**
   * Finish the indices.
   *
   * This function has to be called after all indices are set and before the solve function is
   * called.
   */
  void finish_indices();

  /**
   * Get the number of allocated bytes.
   */
  size_t get_allocation_size();

  {% for nodetype in solver.node_types %}
  /**
   * Set the current values of the {{nodetype.__name__}} nodes from stacked host data.
   *
   * The offset can be used to start writing at a specific index.
   */
  void set_{{nodetype.__name__}}_nodes_from_stacked_host(
      const float* const data, size_t offset, size_t num);

  /**
   * Set the current values of the {{nodetype.__name__}} nodes from stacked device data.
   *
   * The offset can be used to start writing at a specific index.
   */
  void set_{{nodetype.__name__}}_nodes_from_stacked_device(
      const float* const data, size_t offset, size_t num);

  /**
   * Read the current values of the {{nodetype.__name__}} nodes into stacked host output data.
   *
   * The offset can be used to start reading from a specific index.
   */
  void get_{{nodetype.__name__}}_nodes_to_stacked_host(
      float* const data, size_t offset, size_t num);

  /**
   * Read the current values of the {{nodetype.__name__}} nodes into stacked device output data.
   *
   * The offset can be used to start reading from a specific index.
   */
  void get_{{nodetype.__name__}}_nodes_to_stacked_device(
      float* const data, size_t offset, size_t num);

  /**
   * Set the current number of active nodes of type {{nodetype.__name__}}.
   *
   * The value is set during initialization and this function is only needed if you want to
   * change the problem between optimization runs. This is work in progress and can have
   * performance impacts.
   */
  void set_{{nodetype.__name__}}_num(size_t num);

  {% endfor %}
  {% for factor in solver.factors %}
  {% for argname, argtype in factor.node_arg_types.items() %}
  /**
   * Set the indices for the {{argname}} argument of the {{factor.name}} factor from host.
   */
  void set_{{factor.name}}_{{argname}}_indices_from_host(
      const unsigned int* const indices, size_t num);

  /**
   * Set the indices for the {{argname}} argument of the {{factor.name}} factor from device.
   */
  void set_{{factor.name}}_{{argname}}_indices_from_device(
      const unsigned int* const indices, size_t num);

  {% endfor %}
  {% for argname, argtype in factor.const_arg_types.items() %}
  /**
   * Set the values of the {{argname}} constants of the {{factor.name}} factor from stacked
   * host data.
   *
   * The offset can be used to start writing from a specific index.
   */
  void set_{{factor.name}}_{{argname}}_data_from_stacked_host(
      const float* const data, size_t offset, size_t num);

  /**
   * Set the values of the {{argname}} constants of the {{factor.name}} factor from stacked
   * device data.
   *
   * The offset can be used to start writing from a specific index.
   */
  void set_{{factor.name}}_{{argname}}_data_from_stacked_device(
      const float* const data, size_t offset, size_t num);

  {% endfor %}
  /**
   * Set the current number of {{factor.name}} factors.
   *
   * The value is set during initialization and this function is only needed if you want to
   * change the problem between optimization runs. This is work in progress and can have
   * performance impacts.
   */
  void set_{{factor.name}}_num(size_t num);

  {% endfor %}

 private:
  // Every member below params_ carries a default initializer so that an object which has given
  // up its state through a move holds null pointers and zero sizes.
  SolverParams params_;
  uint8_t* origin_ptr_ = nullptr;
  size_t scratch_inout_size_ = 0;
  size_t scratch_sum_size_ = 0;
  size_t allocation_size_ = 0;

  int solver_iter_ = 0;
  int pcg_iter_ = 0;

  bool indices_valid_ = false;

  float pcg_r_0_norm2_ = 0.0f;
  float pcg_r_kp1_norm2_ = 0.0f;

  {% for thing in solver.size_contributors %}
  size_t {{num_key(thing)}} = 0;
  size_t {{num_blocks_key(thing)}} = 0;
  size_t {{num_max_key(thing)}} = 0;
  {% endfor %}

  /**
   * Exchange every member except params_ with `other`.
   *
   * Shared by the move constructor and move assignment so that exactly one object holds a given
   * origin_ptr_ (and the field pointers declared alongside it) at any time.
   */
  void swap_state({{ solver.struct_name }}& other) noexcept {
    std::swap(origin_ptr_, other.origin_ptr_);
    std::swap(scratch_inout_size_, other.scratch_inout_size_);
    std::swap(scratch_sum_size_, other.scratch_sum_size_);
    std::swap(allocation_size_, other.allocation_size_);

    std::swap(solver_iter_, other.solver_iter_);
    std::swap(pcg_iter_, other.pcg_iter_);

    std::swap(indices_valid_, other.indices_valid_);

    std::swap(pcg_r_0_norm2_, other.pcg_r_0_norm2_);
    std::swap(pcg_r_kp1_norm2_, other.pcg_r_kp1_norm2_);

    {% for thing in solver.size_contributors %}
    std::swap({{num_key(thing)}}, other.{{num_key(thing)}});
    std::swap({{num_blocks_key(thing)}}, other.{{num_blocks_key(thing)}});
    std::swap({{num_max_key(thing)}}, other.{{num_max_key(thing)}});
    {% endfor %}

    {% for name, datadesc in solver.fields.items() %}
    std::swap({{ name }}, other.{{ name }});
    {% endfor %}
  }

  size_t get_nbytes();
  float linearize_first();
  void linearize();
  float do_res_jac_first();
  void do_res_jac();
  void do_njtr_precond();
  void do_normalize();
  void do_jp();
  void do_jtjp();
  void do_alpha_first();
  void do_alpha();
  void do_update_step_first();
  void do_update_step();
  void do_update_r_first();
  void do_update_r();
  float do_retract_score();
  void do_beta();
  void do_update_p();
  float get_pred_decrease();

  // One typed pointer per solver field, generated from solver.fields.
  {% for name, datadesc in solver.fields.items() %}
  {{ datadesc.dtype }}* {{ name }} = nullptr;
  {% endfor %}
};

} // namespace caspar
